import random
from collections import defaultdict

import numpy as np
import torch
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, default_data_collator

from .Base import BaseDataset, UnlearnDataset


class TruthfulQADataset(BaseDataset):
    """TruthfulQA ("generation" config) prepared for causal-LM training.

    The ``validation`` split loaded from the hub is divided 90/10 into
    ``train`` and ``test``. Each example is rendered as
    ``"### Question: <q>\\n ### Answer: <best_answer>"``, right-padded to
    ``_MAX_LENGTH`` tokens. Labels are ``-100`` on the question prefix and on
    padding, so the loss is computed on the answer tokens.
    """

    # Every example is truncated / right-padded to exactly this many tokens.
    _MAX_LENGTH = 512
    # Seed applied to torch, numpy and ``random`` before the train/test split.
    _SEED = 8888

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # DatasetDict with "train" and "test" splits (see get_dataset).
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Load TruthfulQA and split it into train/test.

        Side effect: reseeds the global torch, numpy and ``random`` RNGs.

        Returns:
            A ``DatasetDict`` with ``"train"`` and ``"test"`` splits
            (``test_size=0.1``).
        """
        dataset = load_dataset("truthful_qa", "generation", cache_dir="./.cache")[
            "validation"
        ]

        # Seed the global RNGs before splitting; the split is intended to be
        # reproducible across runs.
        torch.manual_seed(self._SEED)
        np.random.seed(self._SEED)
        random.seed(self._SEED)

        return dataset.train_test_split(test_size=0.1)

    def __preprocess__(self, tokenizer):
        """Tokenize both splits in place.

        Every example becomes ``input_ids``, ``attention_mask`` and ``label``
        of length ``_MAX_LENGTH``; all original columns are dropped. Only the
        ``train`` split is switched to torch format. The ``test`` split keeps
        the default format.

        Args:
            tokenizer: Hugging Face-style tokenizer. Its ``pad_token_id`` is
                used to right-pad ``input_ids``.
        """
        max_length = self._MAX_LENGTH

        def preprocess(examples):
            results = {"input_ids": [], "attention_mask": [], "label": []}
            for question, best_answer in zip(
                examples["question"], examples["best_answer"]
            ):
                text = f"### Question: {question}\n ### Answer: {best_answer}"
                # Tokenize WITHOUT padding and right-pad manually below. If the
                # tokenizer padded here, padded positions could not be told
                # apart from real tokens and would be trained on as labels.
                tokenized = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_length,
                    add_special_tokens=True,
                )
                input_ids = list(tokenized["input_ids"])
                attention_mask = list(tokenized["attention_mask"])
                pad_length = max_length - len(input_ids)

                num_question_token = len(
                    tokenizer.tokenize(
                        f"### Question: {question}\n ", add_special_tokens=True
                    )
                )

                # Padding never contributes to the loss.
                label = input_ids + [-100] * pad_length
                # Mask the prompt so the loss covers the answer. Clamp in case
                # the question alone is longer than max_length.
                # NOTE(review): only num_question_token - 1 positions are
                # masked; confirm this offset is intentional (e.g. to allow
                # for the tokenization boundary at "\n ###").
                num_masked = min(max(num_question_token - 1, 0), len(label))
                label[:num_masked] = [-100] * num_masked

                results["input_ids"].append(
                    torch.tensor(input_ids + [tokenizer.pad_token_id] * pad_length)
                )
                results["attention_mask"].append(
                    torch.tensor(attention_mask + [0] * pad_length)
                )
                results["label"].append(torch.tensor(label))
            return results

        for split in ("train", "test"):
            self.dataset[split] = self.dataset[split].map(
                preprocess,
                batched=True,
                # Drop every raw TruthfulQA column; only the tensors remain.
                remove_columns=self.dataset[split].column_names,
            )

        self.dataset["train"].set_format(
            type="torch", columns=["input_ids", "attention_mask", "label"]
        )

    def build_dataset(self, tokenizer):
        """Tokenize the splits with ``tokenizer`` and return the DatasetDict."""
        self.__preprocess__(tokenizer)

        return self.dataset
